import numpy as np  # NumPy是一个用于科学计算的基础包,用于处理大型多维数组和矩阵
import faiss  # FAISS库用于高效的相似度搜索和稠密向量的聚类

# 定义FaissKNeighbors类，用于执行基于FAISS的K近邻搜索
class FaissKNeighbors:
    # 类初始化函数：初始化k值，FAISS资源对象res，以及用于存储数据的索引
    def __init__(self, k=1, res=None):
        self.index = None  # 用于存储训练数据的索引
        self.y = None  # 用于存储训练数据的标签
        self.k = k  # 最近邻个数
        self.res = res  # FAISS GPU资源对象

    # 训练函数：将训练数据加入到FAISS索引中
    def fit(self, X, y):
        # 初始化 self.index 为一个FAISS索引: IndexFlatL2, 该索引使用欧氏距离进行搜索
        if self.res is not None:
            self.index = faiss.IndexFlatL2(X.shape[1])
            self.index = faiss.index_cpu_to_gpu(self.res, 0, self.index)
        else:
            self.index = faiss.IndexFlatL2(X.shape[1])
        # 将训练数据加入到FAISS索引中
        self.index.add(X.astype(np.float32))
        # 初始化 self.y 为传入的 y
        self.y = np.array(y)

    # 预测函数：对新的数据集X进行分类预测
    def predict(self, X):
        # 搜索X中每个向量的k个最近邻
        distances, indices = self.index.search(X.astype(np.float32), self.k)
        # 根据索引获得最近邻的标签
        votes = [self.y[i] for i in indices]
        # 通过投票机制得出最终预测的标签
        predictions = np.array([np.argmax(np.bincount(vote)) for vote in votes])
        return predictions

    # 评分函数：计算预测准确率
    def score(self, X, y_true):
        # 预测
        predictions = self.predict(X)
        # 计算准确率
        accuracy = np.mean(predictions == y_true)
        return accuracy

# 使用示例
if __name__ == "__main__":
    # 创建一些随机数据作为示例
    X_train = np.random.rand(100, 64).astype('float32')
    y_train = np.random.randint(0, 5, 100)
    X_test = np.random.rand(10, 64).astype('float32')
    y_test = np.random.randint(0, 5, 10)

    # 创建模型实例
    knn = FaissKNeighbors(k=3)

    # 训练模型
    knn.fit(X_train, y_train)

    # 预测
    predictions = knn.predict(X_test)

    # 计算准确率
    accuracy = knn.score(X_test, y_test)
    print("Accuracy:", accuracy)